#!/usr/bin/env python
import numpy as np
import tensorflow as tf
import data_read as d

import vae_gan_models as VMG
#import vae_gan_models_shallow as VMG
#import vae_gan_models_deep as VMG

import time

# --- Training schedule -------------------------------------------------------
n_epochs = 10000        # number of iterations of the training loop in trainGAN
k_step = 5              # discriminator optimizer steps per epoch whenever D is trained
save_period = 200       # write a checkpoint every `save_period` epochs
batch_size = 1          # forwarded to VMG.generater
beta = 0.5              # Adam beta1, shared by the generator and discriminator optimizers

# --- Volume sizes: placeholders are shaped (batch, N, N, 10, 1) ---------------
input_len = 800         # N for the encoder input placeholder (x_vector)
output_len = 200        # N for the target placeholder (T_vector)

# --- Output locations ----------------------------------------------------------
MODEL_DIRECTORY = "store/model.ckpt_2"      # checkpoint path handed to saver.save
LOGS_DIRECTORY = "store/model.ckpt_gan_2"   # TensorBoard log dir (loss function updated)

def trainGAN():
    """Build the VAE-GAN graph and run the adversarial training loop.

    The encoder maps the input volume ``x_vector`` to a latent ``z_latent``
    (returning ``mean``/``stddev`` for the KL term), the generator decodes it
    into ``out_image``, and a weight-shared discriminator scores both the
    target ``T_vector`` and ``out_image``.

    Each epoch the losses are evaluated once and written to TensorBoard
    (``LOGS_DIRECTORY``).  The generator and/or discriminator are then updated
    depending on how their losses compare with 0.5.  A checkpoint is written
    to ``MODEL_DIRECTORY`` every ``save_period`` epochs and after the final
    epoch.

    Uses the module-level hyper-parameters (``n_epochs``, ``k_step``,
    ``save_period``, ``batch_size``, ``beta``, ``input_len``, ``output_len``).
    Data comes from ``data_read.data_input`` / ``data_read.data_combine``.
    """
    # Start from a fresh graph so repeated calls do not accumulate ops.
    tf.reset_default_graph()

    # Samples are (N, N, 10, 1) volumes; N differs for input and target.
    x_vector = tf.placeholder(shape=[None, input_len, input_len, 10, 1], dtype=tf.float32)
    T_vector = tf.placeholder(shape=[None, output_len, output_len, 10, 1], dtype=tf.float32)

    # Encoder -> latent sample (+ its mean/stddev), generator -> reconstruction.
    z_latent, mean, stddev = VMG.encoder(x_vector, phase_train=True, reuse=False, training=False)
    out_image = VMG.generater(z_latent, batch_size, phase_train=True, reuse=False, training=False)

    # The second discriminator call reuses the first call's weights.
    d_true = VMG.discriminator(T_vector, phase_train=True, reuse=False)
    d_fake = VMG.discriminator(out_image, phase_train=True, reuse=True)

    # Keeps every tf.log finite. Assumes the discriminator emits probabilities
    # in [0, 1] (e.g. a sigmoid), which can saturate to exactly 0/1 in float32.
    eps = 1e-8

    with tf.name_scope("GLOSS"):
        # Adversarial term: -log D(G(z)).
        g_loss_g = -tf.reduce_mean(tf.log(d_fake + eps))
        # Reconstruction term; tf.nn.l2_loss already returns a scalar.
        # NOTE(review): a target sample has output_len*output_len*10 elements,
        # so output_len**3 may not be the intended normaliser -- kept as-is to
        # preserve the loss scale.
        g_loss_f = tf.nn.l2_loss(T_vector - out_image) / output_len**3
        # KL(N(mean, stddev^2) || N(0, 1)), averaged over all elements.
        g_loss_e = tf.reduce_mean(0.5 * (tf.square(mean) + tf.square(stddev)
                                         - 2.0 * tf.log(stddev + eps) - 1.0))
        # Earlier experiments used (adversarial, reconstruction, KL) weights of
        # (0.9, 0.0, 5), (0.5, 0.5, 5), (0.9, 0.1, 100), (1.0, 0.0, 5), (0.9, 0.1, 1).
        g_loss = 0.9 * g_loss_g + 0.1 * g_loss_f + 5 * g_loss_e
    sum_g_loss = tf.summary.scalar("GLOSS", g_loss)

    with tf.name_scope("DLOSS"):
        d_loss = -tf.reduce_mean(tf.log(d_true + eps) + tf.log(1.0 - d_fake + eps))
    sum_d_loss = tf.summary.scalar("DLOSS", d_loss)

    with tf.name_scope("ADAM"):
        trainable = tf.trainable_variables()
        # Variables are assigned to a network by substring match on their names.
        # NOTE(review): short keys like 'be'/'we' also match e.g. '.../beta' or
        # '.../weights' anywhere -- confirm against vae_gan_models that the two
        # lists do not overlap.
        para_discriminator = [v for v in trainable
                              if any(k in v.name for k in ('wd', 'bd', 'discriminator'))]
        para_generator = [v for v in trainable
                          if any(k in v.name for k in ('wg', 'bg', 'generator', 'we', 'be', 'encoder'))]

        # Step counters drive the learning-rate decay; they are bookkeeping,
        # not weights, so they are kept out of the trainable collection.
        batch_d = tf.Variable(0, trainable=False)
        learning_rate_d = tf.train.exponential_decay(
            0.0001,     # base learning rate (5e-5 also tried)
            batch_d,    # current discriminator step
            100,        # decay steps
            0.99,       # decay rate
            staircase=True)

        batch_g = tf.Variable(0, trainable=False)
        learning_rate_g = tf.train.exponential_decay(
            0.001,      # base learning rate
            batch_g,    # current generator step
            100,        # decay steps
            0.99,       # decay rate
            staircase=True)

        # Each optimizer only updates its own network's variables.
        optimizer_op_generator = tf.train.AdamOptimizer(
            learning_rate=learning_rate_g, beta1=beta).minimize(
                g_loss, var_list=para_generator, global_step=batch_g)
        optimizer_op_discriminator = tf.train.AdamOptimizer(
            learning_rate=learning_rate_d, beta1=beta).minimize(
                d_loss, var_list=para_discriminator, global_step=batch_d)

    sum_d_LR = tf.summary.scalar('learning_rate_d', learning_rate_d)
    sum_g_LR = tf.summary.scalar('learning_rate_g', learning_rate_g)
    d_sum_merge = tf.summary.merge([sum_d_loss, sum_d_LR])
    g_sum_merge = tf.summary.merge([sum_g_loss, sum_g_LR])

    saver = tf.train.Saver()

    # The context manager guarantees the session is closed even on error.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # saver.restore(sess, MODEL_DIRECTORY)  # uncomment to resume training

        summary_writer = tf.summary.FileWriter(LOGS_DIRECTORY, graph=tf.get_default_graph())
        try:
            in_, out_data = d.data_input()

            for epoch in range(n_epochs):
                in_data = d.data_combine(in_)
                feed = {T_vector: out_data, x_vector: in_data}

                # One forward pass for every monitored quantity, so all logged
                # losses come from the same latent sample.
                vae_loss_R, discriminator_loss, d_merge, generator_loss, g_merge = sess.run(
                    [g_loss_e, d_loss, d_sum_merge, g_loss, g_sum_merge], feed_dict=feed)

                # A low loss means that player currently has the upper hand:
                # train only the other one, otherwise train both.
                if discriminator_loss <= 0.5:
                    sess.run(optimizer_op_generator, feed_dict=feed)
                elif generator_loss <= 0.5:
                    for _ in range(k_step):
                        sess.run(optimizer_op_discriminator, feed_dict=feed)
                else:
                    sess.run(optimizer_op_generator, feed_dict=feed)
                    for _ in range(k_step):
                        sess.run(optimizer_op_discriminator, feed_dict=feed)

                print("Epoch:", '%03d,' % (epoch + 1),
                      "ec_loss %.3f, g_loss %.2f, d_loss %.3f" % (vae_loss_R, generator_loss, discriminator_loss))

                summary_writer.add_summary(d_merge, epoch)
                summary_writer.add_summary(g_merge, epoch)

                # Also save after the last epoch so its training is not lost.
                if epoch % save_period == 0 or epoch == n_epochs - 1:
                    save_path = saver.save(sess, MODEL_DIRECTORY)
                    print("Model updated and saved in file: %s" % save_path)
        finally:
            # Flushes pending events to disk and releases the file handle.
            summary_writer.close()

    print("Optimization Finished!")
    print('Run `tensorboard --logdir=%s` to see the results.' % LOGS_DIRECTORY)

if __name__ == '__main__':
    # Wall-clock timing of the whole call: graph construction plus all epochs.
    # NOTE(review): the message below says "sampling" although this script trains.
    t_start = time.time()
    trainGAN()
    elapsed = time.time() - t_start
    print("This sampling run took %5.4f seconds." % elapsed)

